import torch
import torch.nn as nn

class my_MSELoss(nn.Module):
    """Mean squared error loss between two tensors.

    Computes ``(y_true - y_pred) ** 2`` elementwise, then reduces it
    according to ``reduction``. The default, ``"mean"``, averages over every
    element. This matches ``nn.MSELoss()`` and this class's original
    behaviour.

    Args:
        reduction: How to reduce the elementwise squared errors.
            ``"mean"`` (default) averages over all elements. ``"sum"`` adds
            them up. ``"none"`` returns the unreduced tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    # Class-level fallback: an instance pickled before ``reduction`` existed
    # has no such attribute, so it resolves to the original averaging behaviour.
    reduction = "mean"

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.reduction = reduction

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Return the (reduced) squared error between ``y_true`` and ``y_pred``.

        The tensors are combined with ordinary broadcasting. Unlike
        ``nn.MSELoss``, no warning is emitted for mismatched shapes, so callers
        should make sure the shapes agree. Because the difference is squared,
        the argument order does not affect the result.
        """
        squared_error = (y_true - y_pred) ** 2
        if self.reduction == "mean":
            return squared_error.mean()
        if self.reduction == "sum":
            return squared_error.sum()
        # reduction == "none": hand back the elementwise errors untouched.
        return squared_error